{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# python notebook for Make Your Own Neural Network\n",
    "# code for a 3-layer neural network, and code for learning the MNIST dataset\n",
    "# this version trains using the MNIST dataset, then tests on our own images\n",
    "# (c) Tariq Rashid, 2016\n",
    "# license is GPLv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "# scipy.special for the sigmoid function expit()\n",
    "import scipy.special\n",
    "# library for plotting arrays\n",
    "import matplotlib.pyplot\n",
    "# ensure the plots are inside this notebook, not an external window\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# helper to load data from PNG image files\n",
    "import scipy.misc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# neural network class definition\n",
    "class neuralNetwork:\n",
    "    \"\"\"A 3-layer (input, hidden, output) neural network with sigmoid activations.\"\"\"\n",
    "\n",
    "    # initialise the neural network\n",
    "    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):\n",
    "        \"\"\"Set the layer sizes and learning rate, and create random link weights.\n",
    "\n",
    "        inputnodes, hiddennodes, outputnodes -- number of nodes in each layer\n",
    "        learningrate -- factor applied to every weight update made by train()\n",
    "        \"\"\"\n",
    "        # set number of nodes in each input, hidden, output layer\n",
    "        self.inodes = inputnodes\n",
    "        self.hnodes = hiddennodes\n",
    "        self.onodes = outputnodes\n",
    "\n",
    "        # link weight matrices, wih (input -> hidden) and who (hidden -> output)\n",
    "        # each matrix has one row per node in the next layer and one column per node in this layer\n",
    "        # initial weights are sampled from a normal distribution centred on zero, with a\n",
    "        # standard deviation of 1/sqrt(number of incoming links) of the node being fed:\n",
    "        # a hidden node has inodes incoming links, an output node has hnodes incoming links\n",
    "        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))\n",
    "        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))\n",
    "\n",
    "        # learning rate\n",
    "        self.lr = learningrate\n",
    "\n",
    "        # activation function is the sigmoid function\n",
    "        self.activation_function = lambda x: scipy.special.expit(x)\n",
    "\n",
    "    # feed signals forward through the network (shared by train and query)\n",
    "    def _forward(self, inputs):\n",
    "        \"\"\"Return (hidden_outputs, final_outputs) for a 2d column array of inputs.\"\"\"\n",
    "        # calculate signals into hidden layer, then the signals emerging from it\n",
    "        hidden_inputs = numpy.dot(self.wih, inputs)\n",
    "        hidden_outputs = self.activation_function(hidden_inputs)\n",
    "\n",
    "        # calculate signals into final output layer, then the signals emerging from it\n",
    "        final_inputs = numpy.dot(self.who, hidden_outputs)\n",
    "        final_outputs = self.activation_function(final_inputs)\n",
    "\n",
    "        return hidden_outputs, final_outputs\n",
    "\n",
    "    # train the neural network\n",
    "    def train(self, inputs_list, targets_list):\n",
    "        \"\"\"Adjust the link weights using a single training example.\n",
    "\n",
    "        inputs_list -- input values, one per input node\n",
    "        targets_list -- desired output values, one per output node\n",
    "        \"\"\"\n",
    "        # convert inputs and targets lists to 2d column arrays\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "        targets = numpy.array(targets_list, ndmin=2).T\n",
    "\n",
    "        hidden_outputs, final_outputs = self._forward(inputs)\n",
    "\n",
    "        # output layer error is the (target - actual)\n",
    "        output_errors = targets - final_outputs\n",
    "        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes\n",
    "        # (calculated here, before who is updated below, so it uses the pre-update weights)\n",
    "        hidden_errors = numpy.dot(self.who.T, output_errors)\n",
    "\n",
    "        # update the weights for the links between the hidden and output layers\n",
    "        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)), numpy.transpose(hidden_outputs))\n",
    "\n",
    "        # update the weights for the links between the input and hidden layers\n",
    "        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), numpy.transpose(inputs))\n",
    "\n",
    "    # query the neural network\n",
    "    def query(self, inputs_list):\n",
    "        \"\"\"Return the output layer signals for inputs_list (one value per input node).\n",
    "\n",
    "        For a flat list of inputs the result is a 2d column array with one row per output node.\n",
    "        \"\"\"\n",
    "        # convert inputs list to 2d column array\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "\n",
    "        return self._forward(inputs)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# number of input, hidden and output nodes\n",
    "input_nodes = 784\n",
    "hidden_nodes = 200\n",
    "output_nodes = 10\n",
    "\n",
    "# learning rate\n",
    "learning_rate = 0.1\n",
    "\n",
    "# seed numpy's random number generator so that the network's random initial\n",
    "# link weights are the same every time the notebook is run from the top\n",
    "random_seed = 42\n",
    "numpy.random.seed(random_seed)\n",
    "\n",
    "# create instance of neural network\n",
    "n = neuralNetwork(input_nodes,hidden_nodes,output_nodes, learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# load the mnist training data CSV file into a list of text lines, one per record\n",
    "# the with block closes the file even if reading it fails part way through\n",
    "with open(\"mnist_dataset/mnist_train.csv\", 'r') as training_data_file:\n",
    "    training_data_list = training_data_file.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# train the neural network\n",
    "\n",
    "# epochs is the number of times the training data set is used for training\n",
    "epochs = 10\n",
    "\n",
    "for e in range(epochs):\n",
    "    # go through all records in the training data set\n",
    "    for record in training_data_list:\n",
    "        # skip blank lines, which carry no label or pixel values\n",
    "        if not record.strip():\n",
    "            continue\n",
    "        # split the record by the ',' commas\n",
    "        all_values = record.split(',')\n",
    "        # scale and shift the inputs from the range 0..255 to the range 0.01..1.0\n",
    "        # (numpy.asfarray() was removed in NumPy 2.0; asarray with dtype=float is its equivalent)\n",
    "        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01\n",
    "        # create the target output values (all 0.01, except the desired label which is 0.99)\n",
    "        targets = numpy.zeros(output_nodes) + 0.01\n",
    "        # all_values[0] is the target label for this record\n",
    "        targets[int(all_values[0])] = 0.99\n",
    "        n.train(inputs, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the network with our own image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading ... my_own_images/2828_my_own_image.png\n",
      "min =  0.01\n",
      "max =  1.0\n",
      "[[ 0.0053985 ]\n",
      " [ 0.01424362]\n",
      " [ 0.00603518]\n",
      " [ 0.66770827]\n",
      " [ 0.01524452]\n",
      " [ 0.00302715]\n",
      " [ 0.05434234]\n",
      " [ 0.00856272]\n",
      " [ 0.13408102]\n",
      " [ 0.02900324]]\n",
      "network says  3\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD8RJREFUeJzt3X+M1HV+x/HXmx+VLASBi0IUb6/GYC9N1NjUWDFxTq34\n44x4GsW7GLHVaHJakibmPDFhMWK4mmBs4sUocAGjATkj4iVePeJNCFT80WILHosXKnsosBKxpsTw\nY7vv/rHDdnfd+XyH/c6P7+77+UgIw7xmdj7M7mu/M/P5fr8fc3cBiGVcqwcAoPkoPhAQxQcCovhA\nQBQfCIjiAwHlKr6ZXW9mnWb2iZn9rF6DAtBYNtJ5fDMbJ+kTSddIOiDpA0kL3L1zyO3YUQBoEXe3\n4a7Ps8W/TNIf3b3L3U9KWifplioP3v9nyZIlg/5dtD+Mb+yOr8hja8T4UvIU/1xJ+wf8+7PKdQAK\njg/3gIAm5Ljv55K+O+DfsyvXfUtHR0f/5WnTpuV4yMYrlUqtHkIS4xu5Io9Nyj++crmscrlc023z\nfLg3XtIe9X24d1DS+5LucvfdQ27nI30MACNnZvIqH+6NeIvv7v9rZg9Jelt9bxlWDS09gGIa8Ra/\n5gdgiw+0RGqLz4d7QEAUHwiI4gMBUXwgIIoPBETxgYAoPhAQxQcCyrOvPmqQtfPSoUOHkvnKlSuT\n+a5du5L5wYMHk/nx48eTudmw+3/UlGf93y+88MJkft111yXzW2+9NZm3tbUl88jY4gMBUXwgIIoP\nBETxgYAoPhAQxQcC4nj8nLL+b++9914yv+qqq5L5iRMnTntMUUydOjWZv/XWW8n8iiuuqOdwCofj\n8QEMQvGBgCg+EBDFBwKi+EBAFB8IiOIDATGPnyHvYbXt7e3JvKenJ5kvX748mS9YsCCZZ811jx8/\nPplnHZab0tvbm8y7urqS+UsvvZTMV6xYkczPOOOMZL5///5kPmPGjGRedMzjAxiE4gMBUXwgIIoP\nBETxgYAoPhAQxQcCyjWPb2b7JH0tqVfSSXe/bJjbFHoeP2tsWXPRt99+ezLfuHFjMn/hhReS+X33\n3ZfMI3vuueeS+cMPP5zM77333mS+atWqZJ5nH4dmSM3j5z2vfq+kkrt/lfPrAGiivC/1rQ5fA0CT\n5S2tS/qdmX1gZvfXY0AAGi/vS/257n7QzM5S3y+A3e6+deiNOjo6+i+XSiWVSqWcDwtgqHK5rHK5\nXNNtcxXf3Q9W/j5sZq9LukxSsvgAGmPoRnXp0qVVbzvil/pm1mZmUyqXJ0u6TlJ6BUcAhZBniz9T\n0utm5pWv87K7v12fYQFopPDH42eN7fHHH0/mTz31VDK//PLLk/mWLVuS+cSJE5P5WJb1vclac+CC\nCy5I5p9//nky7+zsTOZz5sxJ5q3G8fgABqH4QEAUHwiI4gMBUXwgIIoPBETxgYCYx88Y25VXXpnM\n9+7dm8w//vjjZJ517vaiH/PdSlnfuw0bNiTzO++8M5kvXrw4mT/55JPJvNWYxwcwCMUHAqL4QEAU\nHwiI4gMBUXwgIIoPBJT3nHtj3ubNm5N51nn329rakjnz9I3T3t6e6/6HDx+u00iKhy0+EBDFBwKi\n+EBAFB8IiOIDAVF8ICCKDwQUfh4/ax590qRJDf36qC7rePtjx44l8/vvz7eO67x583Ldv8jY4gMB\nUXwgIIoPBETxgYAoPhAQxQcCovhAQJnz+Ga2StIPJXW7+0WV66ZLWi+pXdI+SXe4+9cNHGfLMA/f\nOFnz9CdOnEjmt912WzLfuXNnMl+4cGEynz9/fjIfzWrZ4v9K0tA9GR6VtNndL5T0jqSf13tgABon\ns/juvlXSV0OuvkXSmsrlNZLG7q9GYAwa6Xv8s929W5Lc/ZCks+s3JACNVq999ZNv1jo6Ovovl0ol\nlUqlOj0sgFPK5bLK5XJNtx1p8bvNbK
a7d5vZLElfpG48sPgAGmPoRnXp0qVVb1vrS32r/Dllk6SF\nlcv3SHrjdAYIoLUyi29mr0j6V0lzzOxPZnavpOWS/tbM9ki6pvJvAKOENXrtejPzRj8GGifre5da\nV+D48ePJ++7atSuZL1q0KJlv3749mV977bXJfNOmTck861wMRd/Hw8zk7sMOkj33gIAoPhAQxQcC\novhAQBQfCIjiAwFRfCCg8OfVL7qsefSTJ08m89deey2Zr1u3Lpl/+umnyfzAgQNVsyNHjiTvm3f/\njqzj8desWZPMR/s8fR5s8YGAKD4QEMUHAqL4QEAUHwiI4gMBUXwgIObxR7lly5Yl8yeeeKJJI6m/\nadOmJfOrr746mWedD2Dy5MmnPaaxgi0+EBDFBwKi+EBAFB8IiOIDAVF8ICCKDwTEefULLuu527Fj\nRzJfvjy91kneY87zfG+3bNmSzLu7u0f8tSVp+vTpyfzDDz9M5ueff36ux281zqsPYBCKDwRE8YGA\nKD4QEMUHAqL4QEAUHwgocx7fzFZJ+qGkbne/qHLdEkn3S/qicrPH3P23Ve7PPH4DjebntqenJ5nv\n3bs3ma9evTqZP/3008l8xowZuR4/63wBrZZ3Hv9XkuYNc/0Kd7+08mfY0gMopsziu/tWSV8NE43d\nZUaAMS7Pe/yHzOwjM1tpZmfWbUQAGm6k59z7paQn3N3N7ElJKyT9fbUbd3R09F8ulUoqlUojfFgA\n1ZTLZZXL5ZpuW9NBOmbWLunNUx/u1ZpVcj7ca6DR/Nzy4V5j1eMgHdOA9/RmNmtA9iNJu0Y+PADN\nlvlS38xekVSS9B0z+5OkJZJ+YGaXSOqVtE/SAw0cI4A643h8tEzen4ve3t5k/sgjjyTzZ555Jpmv\nXbs2md99993JvNU4Hh/AIBQfCIjiAwFRfCAgig8ERPGBgCg+ENBI99UHcst7Tv9x49LbrZtuuimZ\nZ83j79u373SHNGqwxQcCovhAQBQfCIjiAwFRfCAgig8ERPGBgJjHx5g1YUK+H++s4/1HM7b4QEAU\nHwiI4gMBUXwgIIoPBETxgYAoPhAQ8/gYs7Zt25br/uedd16dRlI8bPGBgCg+EBDFBwKi+EBAFB8I\niOIDAVF8ICDLWqPczGZLWitppqReSS+6+z+b2XRJ6yW1S9on6Q53/3qY+3veddARU9bPzYkTJ5L5\nOeeck8y//vpbP66DfPnll8n8zDPPTOatZmZy92EXL6hli98j6R/d/S8l/Y2kn5rZX0h6VNJmd79Q\n0juSfl6vAQNorMziu/shd/+ocvmopN2SZku6RdKays3WSJrfqEECqK/Teo9vZt+TdImk7ZJmunu3\n1PfLQdLZ9R4cgMaoeV99M5si6deSFrn7UTMb+gas6huyjo6O/sulUkmlUun0RgkgU7lcVrlcrum2\nmR/uSZKZTZD0G0lvufuzlet2Syq5e7eZzZL0e3f//jD35cM9jAgf7uWT98M9SVot6Q+nSl+xSdLC\nyuV7JL0x4hECaKrMl/pmNlfSTyTtNLMd6ntJ/5ikX0h61cz+TlKXpDsaOVAA9ZNZfHffJml8lfja\n+g6n/rJeLp48eTKZv/nmm8l83rx5yXzKlCnJPLKs701Wvnjx4mR+5MiRZP7ggw8m86lTpybz0Yw9\n94CAKD4QEMUHAqL4QEAUHwiI4gMBUXwgoJp22c31AC3eZTfrsd99991kPnfu3GR+8803J/NXX301\nmU+aNCmZj2VZ35v169cn87vuuiuZz5o1K5nv2bMnmY/2efx67LILYAyh+EBAFB8IiOIDAVF8ICCK\nDwRE8YGAws/jHzt2LJlfeumlybyzszOZ33jjjcn8+eefT+azZ89O5mbDTtM2Ter5zTrXwcsvv5zM\nH3jggWTe09OTzHfv3p3M58yZk8xb/dzmxTw+gEEoPhAQxQcCovhAQBQfCIjiAwFRfCCgMT+PnyVr\nbN
98800yv/7665P51q1bT3tMA7W1tSXzrP0Ess4nMH58tSUTapNahmrlypXJ+3Z1dSXziRMnJvON\nGzcm8xtuuCGZj/Z5+izM4wMYhOIDAVF8ICCKDwRE8YGAKD4QUGbxzWy2mb1jZh+b2U4ze7hy/RIz\n+8zM/r3yJz2vBaAwMufxzWyWpFnu/pGZTZH0b5JukXSnpP9x9xUZ9y/0PH6WrLEfP348ma9duzaZ\nb9iwIZlv3749mR89ejSZt1LWPPnChQuT+bJly5J51nnzx/o8fZbUPP6ErDu7+yFJhyqXj5rZbknn\nnvradRslgKY5rff4ZvY9SZdIeq9y1UNm9pGZrTSzM+s8NgANUnPxKy/zfy1pkbsflfRLSee7+yXq\ne0WQfMkPoDgyX+pLkplNUF/pX3L3NyTJ3Q8PuMmLkt6sdv+Ojo7+y6VSSaVSaQRDBZBSLpdVLpdr\num1NxZe0WtIf3P3ZU1eY2azK+39J+pGkXdXuPLD4ABpj6EZ16dKlVW+bWXwzmyvpJ5J2mtkOSS7p\nMUk/NrNLJPVK2icpfUpUAIVRy6f62yQNd+zmb+s/HADNEP54/Lwa/X/LOnd81hrv+/fvT+Z557pT\nx/NffPHFyfueddZZuR47+jx9Fo7HBzAIxQcCovhAQBQfCIjiAwFRfCAgig8ExDx+wY3l5455+MZi\nHh/AIBQfCIjiAwE1vfi1Hi/cKowvnyKPr8hjk5o7Poo/BOPLp8jjK/LYpDFefACtR/GBgJoyj9/Q\nBwBQVbV5/IYXH0Dx8FIfCIjiAwE1rfhmdr2ZdZrZJ2b2s2Y9bq3MbJ+Z/YeZ7TCz9wswnlVm1m1m\n/znguulm9raZ7TGzf2nl6kVVxleYhVSHWez1HyrXF+I5bPVitE15j29m4yR9IukaSQckfSBpgbt3\nNvzBa2Rm/yXpr9z9q1aPRZLM7EpJRyWtdfeLKtf9QtKX7v5PlV+e09390QKNb4lqWEi1GRKLvd6r\nAjyHeRejzatZW/zLJP3R3bvc/aSkder7TxaJqUBvfdx9q6Shv4RukbSmcnmNpPlNHdQAVcYnFWQh\nVXc/5O4fVS4flbRb0mwV5DmsMr6mLUbbrB/0cyUNPM/zZ/r//2RRuKTfmdkHZnZ/qwdTxdnu3i31\nr2J8dovHM5zCLaQ6YLHX7ZJmFu05bMVitIXZwhXAXHe/VNKNkn5aeSlbdEWbiy3cQqrDLPY69Dlr\n6XPYqsVom1X8zyV9d8C/Z1euKwx3P1j5+7Ck19X39qRous1sptT/HvGLFo9nEHc/POCsKy9K+utW\njme4xV5VoOew2mK0zXgOm1X8DyRdYGbtZvZnkhZI2tSkx85kZm2V37wys8mSrlNiEdAmMg1+v7dJ\n0sLK5XskvTH0Dk02aHyVIp2SXEi1Sb612KuK9RwOuxjtgLxhz2HT9tyrTEs8q75fNqvcfXlTHrgG\nZvbn6tvKu/rWE3y51eMzs1cklSR9R1K3pCWSNkraIOk8SV2S7nD3/y7Q+H6gvveq/Qupnno/3YLx\nzZW0RdJO9X1fTy32+r6kV9Xi5zAxvh+rCc8hu+wCAfHhHhAQxQcCovhAQBQfCIjiAwFRfCAgig8E\nRPGBgP4PkdniMgOOskcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7feb67a27278>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# test the neural network with our own images\n",
    "\n",
    "# image file to classify, expected to hold 28x28 pixels\n",
    "image_file_name = 'my_own_images/2828_my_own_image.png'\n",
    "\n",
    "# load image data from the png file into an array\n",
    "# (scipy.misc.imread() was removed in SciPy 1.2, so matplotlib reads the file instead;\n",
    "# for png files it returns the pixel values as floats in the range 0.0 to 1.0)\n",
    "print (\"loading ... \" + image_file_name)\n",
    "img_array = matplotlib.pyplot.imread(image_file_name)\n",
    "\n",
    "# a colour image has a third axis of channels, so combine red, green and blue into a single\n",
    "# grey value using the ITU-R 601 luma weights (any alpha channel is ignored)\n",
    "if img_array.ndim == 3:\n",
    "    img_array = numpy.dot(img_array[..., :3], [0.299, 0.587, 0.114])\n",
    "\n",
    "# rescale to the 0 to 255 range that the training data uses\n",
    "img_array = img_array * 255.0\n",
    "\n",
    "# reshape from 28x28 to list of 784 values, invert values\n",
    "img_data  = 255.0 - img_array.reshape(784)\n",
    "\n",
    "# then scale data to range from 0.01 to 1.0\n",
    "img_data = (img_data / 255.0 * 0.99) + 0.01\n",
    "print(\"min = \", numpy.min(img_data))\n",
    "print(\"max = \", numpy.max(img_data))\n",
    "\n",
    "# plot image\n",
    "matplotlib.pyplot.imshow(img_data.reshape(28,28), cmap='Greys', interpolation='None')\n",
    "\n",
    "# query the network\n",
    "outputs = n.query(img_data)\n",
    "print (outputs)\n",
    "\n",
    "# the index of the highest value corresponds to the label\n",
    "label = numpy.argmax(outputs)\n",
    "print(\"network says \", label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
